/**
* @file allocator.cpp
*
* Copyright (C) Huawei Technologies Co., Ltd. 2023-2023. All Rights Reserved.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/

#include <memory>
#include <mutex>
#include <map>
#include "log_inner.h"
#include "runtime/acl_rt_impl.h"
#include "acl/acl_rt_allocator.h"
#include "toolchain/profiling_manager.h"
#include "toolchain/resource_statistics.h"
#include "framework/memory/allocator_desc.h"

namespace {
class AllocatorDesc {
public:
    AllocatorDesc() = default;
    ~AllocatorDesc() = default;
    AllocatorDesc(aclrtAllocator allocator,
                     aclrtAllocatorAllocFunc allocFunc,
                     aclrtAllocatorFreeFunc freeFunc,
                     aclrtAllocatorAllocAdviseFunc allocAdviseFunc,
                     aclrtAllocatorGetAddrFromBlockFunc getAddrFromBlockFunc):
                     obj_(allocator),
                     allocFunc_(allocFunc),
                     freeFunc_(freeFunc),
                     allocAdviseFunc_(allocAdviseFunc),
                     getAddrFromBlockFunc_(getAddrFromBlockFunc) {}
    aclrtAllocator obj_{ nullptr };
    aclrtAllocatorAllocFunc allocFunc_{ nullptr };
    aclrtAllocatorFreeFunc freeFunc_{ nullptr };
    aclrtAllocatorAllocAdviseFunc allocAdviseFunc_{ nullptr };
    aclrtAllocatorGetAddrFromBlockFunc getAddrFromBlockFunc_{ nullptr };
};
std::mutex g_AllocatorDescMutex;
// The first aclrtAllocatorDesc is created by the user, while the second AllocatorDesc is a saved copy.
std::map<aclrtStream, std::pair<aclrtAllocatorDesc, AllocatorDesc>> g_AllocatorDesMap;
}

/**
 * Allocate an empty allocator descriptor.
 *
 * @return an opaque handle whose fields are all nullptr, or nullptr if the allocation failed.
 *         Release the handle with aclrtAllocatorDestroyDescImpl.
 */
aclrtAllocatorDesc aclrtAllocatorCreateDescImpl()
{
    ACL_PROFILING_REG(acl::AclProfType::AclrtAllocatorCreateDesc);
    ACL_ADD_APPLY_TOTAL_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_ALLOCATOR_DESC);
    ACL_LOG_INFO("Create allocator description.");
    // nothrow new: report failure to the C caller as a nullptr handle rather than via an exception.
    AllocatorDesc *const created = new (std::nothrow) AllocatorDesc;
    if (created != nullptr) {
        ACL_ADD_APPLY_SUCCESS_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_ALLOCATOR_DESC);
        return static_cast<aclrtAllocatorDesc>(created);
    }
    ACL_LOG_INNER_ERROR("alloc AllocatorDesc memory failed");
    return nullptr;
}

/**
 * Release a descriptor obtained from aclrtAllocatorCreateDescImpl.
 *
 * @param allocatorDesc descriptor handle to free; must not be null.
 * @return ACL_SUCCESS on success; a null handle is rejected by ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT.
 *
 * NOTE(review): this does not touch g_AllocatorDesMap. If the descriptor is still registered for a stream,
 * aclrtAllocatorGetByStreamImpl keeps handing back this now-dangling handle. The callback snapshot stored
 * next to it stays valid. Callers should unregister (or stop dereferencing the returned handle) before
 * destroying.
 */
aclError aclrtAllocatorDestroyDescImpl(aclrtAllocatorDesc allocatorDesc)
{
    ACL_PROFILING_REG(acl::AclProfType::AclrtAllocatorDestroyDesc);
    ACL_LOG_INFO("Destroy allocator description, allocatorDesc %p.", allocatorDesc);
    ACL_ADD_RELEASE_TOTAL_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_ALLOCATOR_DESC);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    // The handle arrives by value, so the caller's copy cannot be cleared from here; the caller must drop it.
    delete static_cast<AllocatorDesc *>(allocatorDesc);
    ACL_ADD_RELEASE_SUCCESS_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_ALLOCATOR_DESC);
    return ACL_SUCCESS;
}

/**
 * Store the user's allocator object in a descriptor.
 *
 * @param allocatorDesc descriptor to update; must not be null.
 * @param allocator     allocator object; must not be null.
 */
aclError aclrtAllocatorSetObjToDescImpl(aclrtAllocatorDesc allocatorDesc, aclrtAllocator allocator)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocator);
    ACL_LOG_INFO("Set allocator to allocator description, allocatorDesc %p.", allocatorDesc);
    AllocatorDesc &target = *static_cast<AllocatorDesc *>(allocatorDesc);
    target.obj_ = allocator;
    return ACL_SUCCESS;
}

/**
 * Store the allocation callback in a descriptor.
 *
 * @param allocatorDesc descriptor to update; must not be null.
 * @param func          allocation callback; must not be null.
 */
aclError aclrtAllocatorSetAllocFuncToDescImpl(aclrtAllocatorDesc allocatorDesc, aclrtAllocatorAllocFunc func)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(func);
    ACL_LOG_INFO("Set alloc function to allocator description, allocatorDesc %p.", allocatorDesc);
    AllocatorDesc &target = *static_cast<AllocatorDesc *>(allocatorDesc);
    target.allocFunc_ = func;
    return ACL_SUCCESS;
}

/**
 * Store the release callback in a descriptor.
 *
 * @param allocatorDesc descriptor to update; must not be null.
 * @param func          free callback; must not be null.
 */
aclError aclrtAllocatorSetFreeFuncToDescImpl(aclrtAllocatorDesc allocatorDesc, aclrtAllocatorFreeFunc func)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(func);
    ACL_LOG_INFO("Set free function to allocator description, allocatorDesc %p.", allocatorDesc);
    AllocatorDesc &target = *static_cast<AllocatorDesc *>(allocatorDesc);
    target.freeFunc_ = func;
    return ACL_SUCCESS;
}

/**
 * Store the alloc-advise callback in a descriptor.
 *
 * This callback is optional: aclrtAllocatorRegisterImpl does not require it to be set.
 *
 * @param allocatorDesc descriptor to update; must not be null.
 * @param func          alloc-advise callback; must not be null.
 */
aclError aclrtAllocatorSetAllocAdviseFuncToDescImpl(aclrtAllocatorDesc allocatorDesc, aclrtAllocatorAllocAdviseFunc func)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(func);
    ACL_LOG_INFO("Set alloc advise function to allocator description, allocatorDesc %p.", allocatorDesc);
    AllocatorDesc &target = *static_cast<AllocatorDesc *>(allocatorDesc);
    target.allocAdviseFunc_ = func;
    return ACL_SUCCESS;
}

/**
 * Store the block-to-address callback in a descriptor.
 *
 * @param allocatorDesc descriptor to update; must not be null.
 * @param func          get-addr-from-block callback; must not be null.
 */
aclError aclrtAllocatorSetGetAddrFromBlockFuncToDescImpl(aclrtAllocatorDesc allocatorDesc,
                                                         aclrtAllocatorGetAddrFromBlockFunc func)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(func);
    ACL_LOG_INFO("Set get_addr_from_block function to allocator description, allocatorDesc %p.", allocatorDesc);
    AllocatorDesc &target = *static_cast<AllocatorDesc *>(allocatorDesc);
    target.getAddrFromBlockFunc_ = func;
    return ACL_SUCCESS;
}

/**
 * Bind an external allocator to a stream.
 *
 * Requirements on the descriptor:
 * - It must already carry an allocator object plus the alloc, free and get-addr-from-block callbacks.
 * - The alloc-advise callback is optional.
 *
 * A copy of the descriptor's fields is stored for the stream, so later edits to the descriptor do not
 * affect this registration. Registering the same stream again replaces its previous entry.
 *
 * @param stream        stream to bind; must not be null.
 * @param allocatorDesc descriptor created by aclrtAllocatorCreateDescImpl; must not be null.
 * @return ACL_SUCCESS, or ACL_ERROR_INVALID_PARAM when a required field has not been set.
 */
aclError aclrtAllocatorRegisterImpl(aclrtStream stream, aclrtAllocatorDesc allocatorDesc)
{
    // A stream is mandatory when registering an external allocator.
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(stream);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);

    const AllocatorDesc &source = *static_cast<const AllocatorDesc *>(allocatorDesc);
    if (source.obj_ == nullptr) {
        ACL_LOG_INNER_ERROR("Should call aclrtAllocatorSetObjToDesc first.");
        return ACL_ERROR_INVALID_PARAM;
    }
    if (source.allocFunc_ == nullptr) {
        ACL_LOG_INNER_ERROR("Should call aclrtAllocatorSetAllocFuncToDesc first.");
        return ACL_ERROR_INVALID_PARAM;
    }
    if (source.freeFunc_ == nullptr) {
        ACL_LOG_INNER_ERROR("Should call aclrtAllocatorSetFreeFuncToDesc first.");
        return ACL_ERROR_INVALID_PARAM;
    }
    if (source.getAddrFromBlockFunc_ == nullptr) {
        ACL_LOG_INNER_ERROR("Should call aclrtAllocatorSetGetAddrFromBlockFuncToDesc first.");
        return ACL_ERROR_INVALID_PARAM;
    }

    // Take the snapshot before locking so the critical section covers only the map update.
    const AllocatorDesc snapshot(source);
    const std::lock_guard<std::mutex> guard(g_AllocatorDescMutex);
    g_AllocatorDesMap[stream] = std::pair<aclrtAllocatorDesc, AllocatorDesc>(allocatorDesc, snapshot);
    ACL_LOG_INFO("Register external allocator success, stream %p, allocatorDesc %p.", stream, allocatorDesc);
    return ACL_SUCCESS;
}

/**
 * Look up the external allocator registered for a stream.
 *
 * @param stream               stream to query.
 * @param allocatorDesc        receives the descriptor handle passed at registration; must not be null.
 * @param allocator            optional (may be nullptr); receives the stored allocator object.
 * @param allocFunc            optional (may be nullptr); receives the stored alloc callback.
 * @param freeFunc             optional (may be nullptr); receives the stored free callback.
 * @param allocAdviseFunc      optional (may be nullptr); receives the stored alloc-advise callback.
 * @param getAddrFromBlockFunc optional (may be nullptr); receives the stored get-addr-from-block callback.
 * @return ACL_SUCCESS, or ACL_ERROR_INVALID_PARAM (without writing any output) if the stream has no
 *         registration.
 */
aclError aclrtAllocatorGetByStreamImpl(aclrtStream stream,
                                       aclrtAllocatorDesc *allocatorDesc,
                                       aclrtAllocator *allocator,
                                       aclrtAllocatorAllocFunc *allocFunc,
                                       aclrtAllocatorFreeFunc *freeFunc,
                                       aclrtAllocatorAllocAdviseFunc *allocAdviseFunc,
                                       aclrtAllocatorGetAddrFromBlockFunc *getAddrFromBlockFunc)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(allocatorDesc);
    const std::lock_guard<std::mutex> guard(g_AllocatorDescMutex);
    const auto found = g_AllocatorDesMap.find(stream);
    if (found == g_AllocatorDesMap.end()) {
        return ACL_ERROR_INVALID_PARAM;
    }

    *allocatorDesc = found->second.first;
    // The callbacks come from the snapshot taken at registration, not from the user's descriptor.
    const AllocatorDesc &saved = found->second.second;
    if (allocator != nullptr) {
        *allocator = saved.obj_;
    }
    if (allocFunc != nullptr) {
        *allocFunc = saved.allocFunc_;
    }
    if (freeFunc != nullptr) {
        *freeFunc = saved.freeFunc_;
    }
    if (allocAdviseFunc != nullptr) {
        *allocAdviseFunc = saved.allocAdviseFunc_;
    }
    if (getAddrFromBlockFunc != nullptr) {
        *getAddrFromBlockFunc = saved.getAddrFromBlockFunc_;
    }
    ACL_LOG_INFO("Get allocator By Stream success, stream %p.", stream);
    return ACL_SUCCESS;
}

/**
 * Remove the external-allocator registration for a stream.
 *
 * Succeeds even if the stream had no registration. The descriptor itself is not freed here;
 * release it separately with aclrtAllocatorDestroyDescImpl.
 *
 * @param stream stream whose registration is dropped; must not be null.
 */
aclError aclrtAllocatorUnregisterImpl(aclrtStream stream)
{
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(stream);
    const std::lock_guard<std::mutex> guard(g_AllocatorDescMutex);
    static_cast<void>(g_AllocatorDesMap.erase(stream));
    ACL_LOG_INFO("Unregister external allocator success, stream %p.", stream);
    return ACL_SUCCESS;
}
